package com.st.netty.handler;

import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.io.UnsupportedEncodingException;
import java.nio.charset.Charset;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

import javax.servlet.Servlet;
import javax.servlet.ServletContext;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.web.util.UriComponents;
import org.springframework.web.util.UriComponentsBuilder;
import org.springframework.web.util.UriUtils;

import com.fasterxml.jackson.databind.ObjectMapper;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.DefaultHttpResponse;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.stream.ChunkedStream;
import io.netty.util.CharsetUtil;

/**
 * Netty inbound handler that bridges aggregated HTTP requests to a {@link Servlet}.
 *
 * <p>Each {@link FullHttpRequest} is converted into a {@link MockHttpServletRequest}, handed to
 * the servlet, and the resulting {@link MockHttpServletResponse} is written back to the channel.
 * The channel is closed once the response has been written.
 *
 * @author wuxianwei
 */
public class ServletHandler extends SimpleChannelInboundHandler<FullHttpRequest> {

	private static final Logger logger = LoggerFactory.getLogger(ServletHandler.class);

	/** Shared JSON serializer for the generated error bodies; reused instead of allocated per request. */
	private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

	private final Servlet servlet;

	private final ServletContext servletContext;

	/** Name of the charset used to decode query parameters and to encode generated bodies. */
	private final String charset;

	/**
	 * Creates a handler that dispatches every request to the given servlet.
	 *
	 * @param servlet the servlet to invoke; its servlet config supplies the servlet context
	 * @param charset name of the charset used for query parameter decoding and generated bodies
	 */
	public ServletHandler(Servlet servlet, String charset) {
		this.servlet = servlet;
		this.servletContext = servlet.getServletConfig().getServletContext();
		this.charset = charset;
	}

	/**
	 * Wraps a Netty request in a {@link MockHttpServletRequest}.
	 *
	 * @author wuxianwei
	 * @param fullHttpRequest the aggregated Netty request
	 * @return a servlet request carrying the path, method, headers, body and query parameters
	 */
	private MockHttpServletRequest createHttpRequest(FullHttpRequest fullHttpRequest) {
		UriComponents uriComponents = UriComponentsBuilder.fromUriString(fullHttpRequest.getUri()).build();

		MockHttpServletRequest servletRequest = new MockHttpServletRequest(this.servletContext);
		servletRequest.setRequestURI(uriComponents.getPath());
		servletRequest.setPathInfo(uriComponents.getPath());
		servletRequest.setMethod(fullHttpRequest.getMethod().name());

		// Scheme, port and host are only present when the request line holds an absolute URI.
		if (uriComponents.getScheme() != null) {
			servletRequest.setScheme(uriComponents.getScheme());
		}
		if (uriComponents.getPort() != -1) {
			servletRequest.setServerPort(uriComponents.getPort());
		}
		if (uriComponents.getHost() != null) {
			servletRequest.setServerName(uriComponents.getHost());
		}

		// Request headers: copy every entry so that repeated headers keep all of their values.
		for (Map.Entry<String, String> header : fullHttpRequest.headers().entries()) {
			servletRequest.addHeader(header.getKey(), header.getValue());
		}

		// Request body: copy the raw bytes; a decode/re-encode round trip would corrupt
		// payloads that are not text in the platform default charset.
		ByteBuf content = fullHttpRequest.content();
		if (content.isReadable()) {
			byte[] body = new byte[content.readableBytes()];
			// Absolute read: leaves the buffer's reader index untouched.
			content.getBytes(content.readerIndex(), body);
			servletRequest.setContent(body);
		}

		// Request parameters taken from the query string.
		for (Map.Entry<String, List<String>> parameter : uriComponents.getQueryParams().entrySet()) {
			for (String value : parameter.getValue()) {
				try {
					servletRequest.addParameter(UriUtils.decode(parameter.getKey(), charset),
							UriUtils.decode(value, charset));
				} catch (UnsupportedEncodingException e) {
					logger.error("Unsupported charset '{}' while decoding query parameter '{}'", charset,
							parameter.getKey(), e);
				}
			}
		}

		return servletRequest;
	}

	/**
	 * Sends a plain-text error response and closes the channel afterwards.
	 *
	 * @author wuxianwei
	 * @param handlerContext the channel context to write to
	 * @param status the HTTP status to report
	 */
	private void sendError(ChannelHandlerContext handlerContext, HttpResponseStatus status) {

		ByteBuf content = Unpooled.copiedBuffer("Failure: " + status.toString() + "\r\n", Charset.forName(charset));

		FullHttpResponse httpResponse = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, status, content);

		httpResponse.headers().add(HttpHeaders.Names.CONTENT_TYPE, "text/plain; charset=" + charset);
		httpResponse.headers().set(HttpHeaders.Names.CONTENT_LENGTH, content.readableBytes());

		// Flush explicitly: a bare write() is only queued, so its future would never complete
		// and the CLOSE listener would never run.
		handlerContext.writeAndFlush(httpResponse).addListener(ChannelFutureListener.CLOSE);
	}

	/**
	 * Logs the failure and, while the channel is still active, answers with a 500 response.
	 */
	@Override
	public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
		logger.error("Unhandled exception while processing request", cause);
		if (ctx.channel().isActive()) {
			sendError(ctx, HttpResponseStatus.INTERNAL_SERVER_ERROR);
		}
	}

	/**
	 * Dispatches the request to the servlet and writes the servlet's response to the channel.
	 *
	 * <p>For every status other than 200 the servlet's body is replaced by a JSON document with
	 * the keys {@code status}, {@code message} and {@code data}.
	 *
	 * @param channelHandlerContext the channel context to write the response to
	 * @param fullHttpRequest the aggregated request
	 * @throws Exception if the servlet or the JSON serialization fails
	 */
	@Override
	protected void channelRead0(ChannelHandlerContext channelHandlerContext, FullHttpRequest fullHttpRequest)
			throws Exception {
		if (fullHttpRequest.getDecoderResult().isFailure()) {
			sendError(channelHandlerContext, HttpResponseStatus.BAD_REQUEST);
			return;
		}
		MockHttpServletRequest servletRequest = createHttpRequest(fullHttpRequest);
		MockHttpServletResponse servletResponse = new MockHttpServletResponse();

		servlet.service(servletRequest, servletResponse);

		HttpResponseStatus responseStatus = HttpResponseStatus.valueOf(servletResponse.getStatus());
		HttpResponse response = new DefaultHttpResponse(HttpVersion.HTTP_1_1, responseStatus);

		for (String name : servletResponse.getHeaderNames()) {
			for (Object value : servletResponse.getHeaderValues(name)) {
				response.headers().add(name, value);
			}
		}

		byte[] responseContent = servletResponse.getContentAsByteArray();
		if (response.getStatus().code() != 200) {
			// NOTE(review): this also replaces the body of other 2xx/3xx responses, and the
			// servlet's Content-Type header is left as is - confirm both are intended.
			Map<String, Object> responseMap = new HashMap<String, Object>();
			responseMap.put("status", response.getStatus().code());
			responseMap.put("message", response.getStatus().reasonPhrase());
			responseMap.put("data", null);
			responseContent = OBJECT_MAPPER.writeValueAsString(responseMap).getBytes(charset);

			// Any Content-Length set by the servlet described the discarded body.
			response.headers().set(HttpHeaders.Names.CONTENT_LENGTH, responseContent.length);
		}

		channelHandlerContext.write(response);

		InputStream inputStream = new ByteArrayInputStream(responseContent);

		ChannelFuture channelFuture = channelHandlerContext.writeAndFlush(new ChunkedStream(inputStream));

		channelFuture.addListener(ChannelFutureListener.CLOSE);
	}
}
